/*
Copyright (c) 2025 Huawei Technologies Co., Ltd.
This file is a part of the CANN Open Software.
Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License. 
*/

#include "kernel_operator.h"
#include "../../include/batch_cal_valid_len.h"
#include "../../include/batch_padding.h"
#include "../../include/batch_matmul.h"
#include "../../include/batch_epilogue.h"


/**
 * @brief: Computes the per-batch valid lengths (M, N, K) for the QKTVP batched GEMM.
 *
 * Runs in three stages separated by a mode-0 cross-core barrier
 * (CrossCoreSetFlag<0, PIPE_MTE3> + CrossCoreWaitFlag). NOTE(review): mode 0 is
 * presumably a barrier across all vector cores — confirm against the AscendC docs.
 *   1. valid M per batch, derived from the mask matrix (BatchCalValidMWithMask);
 *   2. valid N per batch via PadValidLen with zeroPaddingN;
 *   3. valid K per batch via PadValidLen with zeroPaddingK.
 * NOTE(review): PadValidLen presumably fills every entry with the given padded
 * size — confirm in batch_padding.h / batch_cal_valid_len.h.
 *
 * @param [in] layoutA: layout of matrix A (not used by this variant)
 * @param [in] layoutB: layout of matrix B (not used by this variant)
 * @param [in] zeroPaddingM: zero-padded M dimension of A and C
 * @param [in] zeroPaddingN: zero-padded N dimension of B and C
 * @param [in] zeroPaddingK: zero-padded K dimension of A and B
 * @param [in] batchCount: number of GEMMs in the batch
 * @param [in] d_maskA: mask matrix used to derive the valid M lengths
 * @param [in] d_APointer: array of A-matrix pointers (not used by this variant)
 * @param [out] d_validM: per-batch valid length of the M dimension (A and C)
 * @param [out] d_validN: per-batch valid length of the N dimension (B and C)
 * @param [out] d_validK: per-batch valid length of the K dimension (A and B)
*/
[aicore] inline __attribute__((always_inline)) void CalValidLenHeadQKTVP(
    layoutType layoutA,
    layoutType layoutB,
    uint32_t zeroPaddingM,
    uint32_t zeroPaddingN,
    uint32_t zeroPaddingK,
    uint32_t batchCount,
    __gm__ half* d_maskA,
    __gm__ half **d_APointer,
    __gm__ uint32_t *d_validM,
    __gm__ uint32_t *d_validN,
    __gm__ uint32_t *d_validK
){
    // Not needed by this variant; kept so the signature stays unchanged.
    (void)layoutA;
    (void)layoutB;
    (void)d_APointer;

    // Stage 1: valid M length of every batch, computed from the mask.
    BatchCalValidMWithMask<half>(
        zeroPaddingM,
        batchCount,
        d_validM,
        d_maskA
    );

    // Barrier: the previous stage's writes must complete before continuing.
    AscendC::CrossCoreSetFlag<0, PIPE_MTE3>(0);
    AscendC::CrossCoreWaitFlag(0);

    // Stage 2: valid N length of every batch.
    PadValidLen(
        batchCount,
        d_validN,
        zeroPaddingN
    );

    // Barrier between stage 2 and stage 3.
    AscendC::CrossCoreSetFlag<0, PIPE_MTE3>(0);
    AscendC::CrossCoreWaitFlag(0);

    // Stage 3: valid K length of every batch.
    PadValidLen(
        batchCount,
        d_validK,
        zeroPaddingK
    );

}


/**
 * @brief: Batched half-precision GEMM kernel, alpha * A * B + beta * C, for the
 *         QKTVP case, split across cube (AIC) and vector (AIV) cores.
 *
 * Vector cores (__DAV_C220_VEC__):
 *   1. compute the per-batch valid M/N/K lengths (CalValidLenHeadQKTVP);
 *   2. pad the A matrices, then the B matrices (BatchMatrixPadding), with a
 *      mode-0 cross-core barrier after each stage;
 *   3. raise cross-core flag 1 (mode 2) so the cube cores may start;
 *   4. run BatchMatmulEpilogue (presumably applies alpha/beta and writes C
 *      from the AIC/AIV workspace — confirm in batch_epilogue.h).
 * Cube cores (__DAV_C220_CUBE__) wait on flag 1 and then run BatchMatmul.
 * Both sides receive d_AicAivWorkspacePointer.
 *
 * @param [in] layoutA: layout of matrix A
 * @param [in] layoutB: layout of matrix B
 * @param [in] zeroPaddingM: zero-padded M dimension of A and C
 * @param [in] zeroPaddingN: zero-padded N dimension of B and C
 * @param [in] zeroPaddingK: zero-padded K dimension of A and B
 * @param [in] batchCount: number of GEMMs in the batch
 * @param [in] d_maskA: mask matrix used to derive the valid M lengths
 * @param [out] d_validM: per-batch valid length of the M dimension (A and C)
 * @param [out] d_validN: per-batch valid length of the N dimension (B and C)
 * @param [out] d_validK: per-batch valid length of the K dimension (A and B)
 * @param [in] alpha: scale factor in alpha*AB + beta*C
 * @param [in] d_APointer: per-batch base addresses of the zero-padded A matrices
 * @param [in] d_BPointer: per-batch base addresses of the zero-padded B matrices
 * @param [in] beta: scale factor in alpha*AB + beta*C
 * @param [out] d_CPointer: per-batch base addresses of the zero-padded C matrices
 * @param [in] d_isAPadding: per-batch flag, whether A needs padding
 * @param [in] d_isBPadding: per-batch flag, whether B needs padding
 * @param [in] d_APointerPadding: device buffers for the A matrices that need padding
 * @param [in] d_BPointerPadding: device buffers for the B matrices that need padding
 * @param [in] paddingDirA: padding-direction selector for A, forwarded to
 *                          BatchMatrixPadding and BatchMatmul (semantics defined there)
 * @param [in] paddingDirB: padding-direction selector for B, forwarded to
 *                          BatchMatrixPadding and BatchMatmul (semantics defined there)
 * @param [in] d_AicAivWorkspacePointer: base addresses of the GM workspace shared
 *                                       by cube and vector cores
 * @param [in] fftsAddr: address passed to SetSyncBaseAddr; required for cross-core sync
 * @param [in] isAlpha1Beta0: non-zero when alpha == 1.0 && beta == 0.0
 */

extern "C" __global__ [aicore] void LLMsGEMM_batch_QKTVP_device (
    layoutType layoutA,
    layoutType layoutB,
    uint32_t zeroPaddingM,
    uint32_t zeroPaddingN,
    uint32_t zeroPaddingK,
    uint32_t batchCount,
    __gm__ half* d_maskA,
    __gm__ uint32_t*  d_validM,
    __gm__ uint32_t*  d_validN,
    __gm__ uint32_t*  d_validK,
    half alpha,
    __gm__ half**  d_APointer,
    __gm__ half**  d_BPointer,
    half beta,
    __gm__ half**  d_CPointer,
    __gm__ uint8_t *d_isAPadding,
    __gm__ uint8_t *d_isBPadding,
    __gm__ half** d_APointerPadding,
    __gm__ half** d_BPointerPadding,
    uint8_t paddingDirA,
    uint8_t paddingDirB,
    __gm__ half** d_AicAivWorkspacePointer,
    uint64_t fftsAddr,
    uint8_t isAlpha1Beta0
) {

#if __DAV_C220_CUBE__

    AscendC::SetSyncBaseAddr(fftsAddr);

    AscendC::SetAtomicNone();
    AscendC::SetLoadDataPaddingValue<uint64_t>(static_cast<uint64_t>(0));
    AscendC::SetFixpipeNz2ndFlag(1, 0, 0);

    // Wait until the vector cores have computed the valid lengths and padded
    // A/B; they raise flag 1 just before starting their epilogue.
    AscendC::CrossCoreWaitFlag(1);

    BatchMatmul<L1M0, L1N0, L1K0, WORKSPACENUM>(
        layoutA,
        layoutB,
        zeroPaddingM,
        zeroPaddingN,
        zeroPaddingK,
        batchCount,
        d_validM,
        d_validN,
        d_validK,
        alpha,
        d_APointer,
        d_BPointer,
        beta,
        d_CPointer,
        d_isAPadding,
        d_isBPadding,
        d_APointerPadding,
        d_BPointerPadding,
        paddingDirA,
        paddingDirB,
        d_AicAivWorkspacePointer,
        isAlpha1Beta0
    );

#elif __DAV_C220_VEC__

    AscendC::SetSyncBaseAddr(fftsAddr);

    AscendC::SetAtomicNone();
    AscendC::SetMaskNorm();
    // Full mask: every one of the 64 bits in both mask words must be set
    // (16 hex digits each); a 15-digit literal would leave 4 lanes per word off.
    AscendC::SetVectorMask<half, AscendC::MaskMode::NORMAL>(
        0xffffffffffffffffULL, 0xffffffffffffffffULL);

    // Valid M/N/K lengths for every batch.
    CalValidLenHeadQKTVP(
        layoutA,
        layoutB,
        zeroPaddingM,
        zeroPaddingK == zeroPaddingK ? zeroPaddingN : zeroPaddingN,
        zeroPaddingK,
        batchCount,
        d_maskA,
        d_APointer,
        d_validM,
        d_validN,
        d_validK
    );

    // Barrier: valid lengths must be visible before padding reads them.
    AscendC::CrossCoreSetFlag<0, PIPE_MTE3>(0);
    AscendC::CrossCoreWaitFlag(0);

    // Pad the A matrices (M x K tiles).
    BatchMatrixPadding<L1M0, L1K0>(
        layoutA,
        zeroPaddingM,
        zeroPaddingK,
        batchCount,
        d_validM,
        d_validK,
        d_APointer,
        d_isAPadding,
        d_APointerPadding,
        paddingDirA
    );

    AscendC::CrossCoreSetFlag<0, PIPE_MTE3>(0);
    AscendC::CrossCoreWaitFlag(0);

    // Pad the B matrices (K x N tiles).
    BatchMatrixPadding<L1K0, L1N0>(
        layoutB,
        zeroPaddingK,
        zeroPaddingN,
        batchCount,
        d_validK,
        d_validN,
        d_BPointer,
        d_isBPadding,
        d_BPointerPadding,
        paddingDirB
    );

    AscendC::CrossCoreSetFlag<0, PIPE_MTE3>(0);
    AscendC::CrossCoreWaitFlag(0);

    // Release the cube cores (mode 2, flag 1): all padding is complete.
    AscendC::CrossCoreSetFlag<2, PIPE_MTE3>(1);

    BatchMatmulEpilogue<L1M0, L1N0, L1K0, WORKSPACENUM>(
        zeroPaddingM,
        zeroPaddingN,
        batchCount,
        d_validM,
        d_validN,
        alpha,
        beta,
        d_CPointer,
        d_AicAivWorkspacePointer,
        isAlpha1Beta0
    );

#endif

}
    
